import os
import torch
import transformers

from torch import nn
from torch.profiler import profile, record_function, ProfilerActivity
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

# File name for a serialized copy of the TrainingArguments saved next to a model checkpoint.
# NOTE(review): no live code in this module reads this constant; presumably kept for a
# custom save routine — confirm before removing.
TRAINING_ARGS_NAME = "training_args.bin"

class PeftTrainer(transformers.Trainer):
    """``transformers.Trainer`` subclass that profiles each training step with the torch profiler.

    When :attr:`profile_training_steps` is true (the default), every call to
    :meth:`training_step` runs inside ``torch.profiler.profile``. The profiler records
    CPU activity, plus CUDA activity when a GPU is available. Memory, stack and
    tensor-shape recording are enabled. The resulting Chrome trace is written to
    :attr:`profile_trace_path`, overwriting the trace from the previous step, so the
    file always holds the most recent step.

    Profiling with these options adds substantial per-step overhead. Set
    ``profile_training_steps = False`` on the class or an instance to train without it.
    """

    #: Whether each training step is wrapped in the torch profiler (True keeps the original behavior).
    profile_training_steps: bool = True
    #: Destination of the exported Chrome trace; a relative path resolves against the process CWD.
    #: NOTE(review): in multi-process training every rank writes this same path — confirm that is intended.
    profile_trace_path: str = "model_trace.json"

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run one ``Trainer`` training step, optionally under the torch profiler.

        Args:
            model: Model being trained; forwarded unchanged to ``Trainer.training_step``.
            inputs: Batch for this step; forwarded unchanged.
            *args: Extra positional arguments that newer ``transformers`` releases pass
                (e.g. ``num_items_in_batch``); forwarded unchanged so this override stays
                compatible across versions.
            **kwargs: Extra keyword arguments; forwarded unchanged.

        Returns:
            The tensor returned by ``Trainer.training_step`` (the training loss for this batch).
        """
        if not self.profile_training_steps:
            return super().training_step(model, inputs, *args, **kwargs)

        # Only request CUDA tracing when a GPU is actually present; `use_cuda` is
        # deprecated in favor of `activities`, so it is no longer passed.
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)

        with profile(
            activities=activities,
            profile_memory=True,
            with_stack=True,
            record_shapes=True,
        ) as prof:
            loss = super().training_step(model, inputs, *args, **kwargs)

        # Export after the context exits so the trace is complete.
        prof.export_chrome_trace(self.profile_trace_path)
        return loss
    
'''
class PeftSavingCallback(TrainerCallback):  
    """ Correctly save PEFT model and not full model """
    def _save(self, model, folder):
        peft_model_path = os.path.join(folder, "adapter_model")
        model.save_pretrained(peft_model_path)

    def on_train_end(self, args: TrainingArguments, state: TrainerState,
            control: TrainerControl, **kwargs):
        """ Save final best model adapter """
        self._save(kwargs['model'], state.best_model_checkpoint)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
            control: TrainerControl, **kwargs):
        """ Save intermediate model adapters in case of interrupted training """
        folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        self._save(kwargs['model'], folder)
'''